{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0GTyDZmDVBml"
   },
   "source": [
    "# HyperStyle Domain Adaptation Notebook\n",
    "\n",
    "In the paper, we show that the weight offsets predicted by HyperStyle over the FFHQ domain are also applicable to fine-tuned generators such as Toonify and StyleGAN-NADA. We demonstrate this idea below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1rvG7nRoGDUY"
   },
   "source": [
    "## Prepare Environment and Download HyperStyle Code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "QRzve9Y5DucV",
    "outputId": "9e1b94ee-0b39-4af0-c2f4-2f619f09cb35"
   },
   "outputs": [],
   "source": [
    "#@title Clone HyperStyle Repo and Install Ninja { display-mode: \"form\" } \n",
    "import os\n",
    "os.chdir('/content')\n",
    "CODE_DIR = 'hyperstyle'\n",
    "\n",
    "## clone repo (skipped when the checkout already exists, so re-running this cell is safe)\n",
    "if not os.path.isdir(CODE_DIR):\n",
    "    !git clone https://github.com/yuval-alaluf/hyperstyle.git $CODE_DIR\n",
    "\n",
    "## install ninja\n",
    "## wget -O rewrites the archive instead of piling up ninja-linux.zip.N copies on re-runs;\n",
    "## unzip -o overwrites the existing binary instead of stopping at an interactive prompt\n",
    "!wget -O ninja-linux.zip https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip\n",
    "!sudo unzip -o ninja-linux.zip -d /usr/local/bin/\n",
    "!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force\n",
    "\n",
    "## all later cells use paths relative to the repository checkout\n",
    "os.chdir(f'./{CODE_DIR}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DxeMz-doE_pr"
   },
   "outputs": [],
   "source": [
    "#@title Import Packages { display-mode: \"form\" }\n",
    "import time\n",
    "import sys\n",
    "import pprint\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "# Make the current working directory and its parent importable for the project modules below.\n",
    "sys.path.extend([\".\", \"..\"])\n",
    "\n",
    "from notebooks.notebook_utils import (\n",
    "    Downloader,\n",
    "    HYPERSTYLE_PATHS,\n",
    "    W_ENCODERS_PATHS,\n",
    "    FINETUNED_MODELS,\n",
    "    RESTYLE_E4E_MODELS,\n",
    "    run_alignment,\n",
    ")\n",
    "from utils.common import tensor2im\n",
    "from utils.inference_utils import run_inversion\n",
    "from utils.domain_adaptation_utils import run_domain_adaptation\n",
    "from utils.model_utils import load_model, load_generator\n",
    "\n",
    "# Pick up edits to imported modules without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gKGnZ8DmICRU"
   },
   "source": [
    "## Define Download Configuration\n",
    "Select below whether you wish to download all models using `pydrive`. Note that if you do not use `pydrive`, you may encounter a \"quota exceeded\" error from Google Drive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lhfYF7xGawvm"
   },
   "outputs": [],
   "source": [
    "#@title Download Configuration { display-mode: \"form\" }\n",
    "# Form toggle forwarded to the Downloader as `use_pydrive`.\n",
    "download_with_pydrive = True #@param {type:\"boolean\"}\n",
    "# Downloader instance that the later cells call `download_file` on to fetch each checkpoint.\n",
    "downloader = Downloader(code_dir=CODE_DIR, use_pydrive=download_with_pydrive)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "k9ZlP8LWhOTD"
   },
   "source": [
    "## Define Fine-tuned Generator Type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FKbAFK7_OSDO"
   },
   "outputs": [],
   "source": [
    "#@title { display-mode: \"form\" }\n",
    "# Key into FINETUNED_MODELS; selects which fine-tuned generator is downloaded and loaded below.\n",
    "generator_type = 'toonify' #@param ['toonify', 'pixar', 'sketch', 'disney_princess']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MRQXrJ8nbin_"
   },
   "source": [
    "## Define and Download All Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ERkgQzGB90Bz"
   },
   "outputs": [],
   "source": [
    "#@title Download HyperStyle model { display-mode: \"form\" }\n",
    "hyperstyle_path = \"./pretrained_models/hyperstyle_ffhq.pt\" #@param {type:\"string\"}\n",
    "w_encoder_path = \"./pretrained_models/faces_w_encoder.pt\" #@param {type:\"string\"}\n",
    "\n",
    "# Files smaller than this are treated as failed downloads.\n",
    "MIN_CHECKPOINT_BYTES = 1000000\n",
    "\n",
    "\n",
    "def is_valid_checkpoint(path):\n",
    "    \"\"\"Return True when `path` exists and holds at least MIN_CHECKPOINT_BYTES bytes.\n",
    "\n",
    "    Existence is checked first so that a missing file yields False instead of\n",
    "    an OSError from os.path.getsize.\n",
    "    \"\"\"\n",
    "    return os.path.exists(path) and os.path.getsize(path) >= MIN_CHECKPOINT_BYTES\n",
    "\n",
    "\n",
    "def ensure_checkpoint(path, model_info, downloading_msg, exists_msg):\n",
    "    \"\"\"Download a checkpoint unless a valid copy is already present at `path`.\n",
    "\n",
    "    Args:\n",
    "        path: Where the checkpoint is expected to be found after the download.\n",
    "        model_info: Mapping whose 'id' and 'name' entries are passed to the downloader.\n",
    "        downloading_msg: Message printed before the download starts.\n",
    "        exists_msg: Message printed when the download is skipped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the file is missing or too small after the download.\n",
    "    \"\"\"\n",
    "    if is_valid_checkpoint(path):\n",
    "        print(exists_msg)\n",
    "        return\n",
    "    print(downloading_msg)\n",
    "    downloader.download_file(file_id=model_info['id'], file_name=model_info['name'])\n",
    "    # if google drive receives too many requests, we'll reach the quota limit and be unable to download the model\n",
    "    if not is_valid_checkpoint(path):\n",
    "        raise ValueError(\"Pretrained model was unable to be downloaded correctly!\")\n",
    "    print('Done.')\n",
    "\n",
    "\n",
    "ensure_checkpoint(hyperstyle_path,\n",
    "                  HYPERSTYLE_PATHS[\"faces\"],\n",
    "                  downloading_msg='Downloading HyperStyle model for faces...',\n",
    "                  exists_msg='HyperStyle model for faces already exists!')\n",
    "\n",
    "ensure_checkpoint(w_encoder_path,\n",
    "                  W_ENCODERS_PATHS[\"faces\"],\n",
    "                  downloading_msg='Downloading the WEncoder model for faces...',\n",
    "                  exists_msg='WEncoder model for faces already exists!')\n",
    "\n",
    "# Load HyperStyle, pointing its options at the downloaded WEncoder checkpoint.\n",
    "net, opts = load_model(hyperstyle_path, update_opts={\"w_encoder_checkpoint_path\": w_encoder_path})\n",
    "print('Model successfully loaded!')\n",
    "pprint.pprint(vars(opts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "n8W2uTB96mGR"
   },
   "outputs": [],
   "source": [
    "#@title Download Fine-tuned Generator { display-mode: \"form\" }\n",
    "generator_info = FINETUNED_MODELS[generator_type]\n",
    "generator_path = f\"./pretrained_models/{generator_info['name']}\"\n",
    "\n",
    "# A failed Google Drive download (e.g. quota exceeded) can leave a tiny file behind, so a file\n",
    "# under 1MB is treated as missing, matching the size check used for the HyperStyle checkpoints.\n",
    "if not os.path.exists(generator_path) or os.path.getsize(generator_path) < 1000000:\n",
    "    print(f'Downloading fine-tuned {generator_type} generator...')\n",
    "    downloader.download_file(file_id=generator_info[\"id\"], file_name=generator_info['name'])\n",
    "    # fail here with a clear message rather than later inside load_generator\n",
    "    if not os.path.exists(generator_path) or os.path.getsize(generator_path) < 1000000:\n",
    "        raise ValueError(\"Pretrained model was unable to be downloaded correctly!\")\n",
    "    print('Done.')\n",
    "else:\n",
    "    print(f'Fine-tuned {generator_type} generator already exists!')\n",
    "\n",
    "# NOTE(review): `transform` is not referenced by any other cell in this notebook; the inference\n",
    "# cell builds its own `img_transforms` with the same steps. Kept for backward compatibility.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((256, 256)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])\n",
    "\n",
    "fine_tuned_generator = load_generator(generator_path)\n",
    "print(f'Fine-tuned {generator_type} generator successfully loaded!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TdhfpGpI6mGR",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#@title Download ReStyle-e4e Model { display-mode: \"form\" }\n",
    "\n",
    "restyle_e4e_path = os.path.join(\"./pretrained_models\", RESTYLE_E4E_MODELS['name'])\n",
    "\n",
    "# A failed Google Drive download (e.g. quota exceeded) can leave a tiny file behind, so a file\n",
    "# under 1MB is treated as missing, matching the size check used for the HyperStyle checkpoints.\n",
    "if not os.path.exists(restyle_e4e_path) or os.path.getsize(restyle_e4e_path) < 1000000:\n",
    "    print('Downloading ReStyle-e4e model...')\n",
    "    downloader.download_file(file_id=RESTYLE_E4E_MODELS[\"id\"], file_name=RESTYLE_E4E_MODELS[\"name\"])\n",
    "    # fail here with a clear message rather than later inside load_model\n",
    "    if not os.path.exists(restyle_e4e_path) or os.path.getsize(restyle_e4e_path) < 1000000:\n",
    "        raise ValueError(\"Pretrained model was unable to be downloaded correctly!\")\n",
    "    print('Done.')\n",
    "else:\n",
    "    print('ReStyle-e4e model already exists!')\n",
    "\n",
    "# is_restyle_encoder=True tells load_model that this checkpoint is the ReStyle encoder.\n",
    "restyle_e4e, restyle_e4e_opts = load_model(restyle_e4e_path, is_restyle_encoder=True)\n",
    "print('ReStyle-e4e model successfully loaded!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XW-CJsuwOSDO"
   },
   "outputs": [],
   "source": [
    "#@title Download (And Align) Image { display-mode: \"form\" }\n",
    "\n",
    "image_path = \"./notebooks/images/domain_adaptation.jpg\" #@param {type:\"string\"}\n",
    "\n",
    "input_is_aligned = False #@param {type:\"boolean\"}\n",
    "# Aligned inputs are opened directly as RGB; anything else goes through run_alignment first.\n",
    "if input_is_aligned:\n",
    "    input_image = Image.open(image_path).convert(\"RGB\")\n",
    "else:\n",
    "    input_image = run_alignment(image_path)\n",
    "\n",
    "# Cell output: a 256x256 preview. The result is not assigned back to input_image.\n",
    "input_image.resize((256, 256))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hMayt52Wgncp"
   },
   "source": [
    "## Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o81i-MtOOSDQ",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "#@title { display-mode: \"form\" } \n",
    "img_transforms = transforms.Compose([\n",
    "    transforms.Resize((256, 256)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n",
    "])\n",
    "transformed_image = img_transforms(input_image)\n",
    "\n",
    "# Apply the same inference settings to the ReStyle-e4e options and the HyperStyle options.\n",
    "for model_opts in (restyle_e4e_opts, opts):\n",
    "    model_opts.n_iters_per_batch = 5\n",
    "    model_opts.resize_outputs = False  # generate outputs at full resolution\n",
    "\n",
    "with torch.no_grad():\n",
    "    tic = time.time()\n",
    "    # Add a leading batch dimension and move the input to the GPU (inside the timed region).\n",
    "    input_batch = transformed_image.unsqueeze(0).cuda()\n",
    "    result, _ = run_domain_adaptation(input_batch,\n",
    "                                      net,\n",
    "                                      opts,\n",
    "                                      fine_tuned_generator,\n",
    "                                      restyle_e4e,\n",
    "                                      restyle_e4e_opts)\n",
    "    toc = time.time()\n",
    "    print('Inference took {:.4f} seconds.'.format(toc - tic))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ezzdd3G3g7fN"
   },
   "source": [
    "## Visualize Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cCoL6QE56mGR"
   },
   "outputs": [],
   "source": [
    "#@title { display-mode: \"form\" } \n",
    "# Display size: 256x256 when outputs were resized, otherwise opts.output_size on both sides.\n",
    "if opts.resize_outputs:\n",
    "    resize_amount = (256, 256)\n",
    "else:\n",
    "    resize_amount = (opts.output_size, opts.output_size)\n",
    "\n",
    "input_im = tensor2im(transformed_image).resize(resize_amount)\n",
    "final_res = tensor2im(result[0]).resize(resize_amount)\n",
    "\n",
    "# Input on the left, domain-adapted output on the right.\n",
    "res = Image.fromarray(np.hstack([np.array(input_im), np.array(final_res)]))\n",
    "res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NuR0F8Cgg96U"
   },
   "source": [
    "## Save Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LPP8Rpjk6mGS"
   },
   "outputs": [],
   "source": [
    "#@title { display-mode: \"form\" } \n",
    "# The side-by-side image is saved under the input image's file name, in a per-generator folder.\n",
    "output_name = os.path.basename(image_path)\n",
    "outputs_path = f\"./outputs/domain_adaptation/{generator_type}\"\n",
    "os.makedirs(outputs_path, exist_ok=True)\n",
    "res.save(os.path.join(outputs_path, output_name))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "gKGnZ8DmICRU",
    "k9ZlP8LWhOTD",
    "MRQXrJ8nbin_"
   ],
   "name": "domain_adaptation_playground.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}